from typing import Any, List, Tuple, Dict
import pandas as pd
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score
from xgboost import XGBClassifier
import copy
import numpy as np
from sklearn.linear_model import LogisticRegression
from xgboost import XGBClassifier
from dataclasses import dataclass, field


@dataclass
class DataToBinsMap:
    """
    Container for the per-bin state of a binned classifier.

    It groups the quantile table of the training set, the mapping from a
    combined bin to that bin's classifier, the training rows and label
    counts of each combined bin, and the per-bin decision thresholds and
    feature importances.
    """

    # Quantile table of the binning features; None until it has been set.
    quantiles: pd.DataFrame = None
    # Combined bin -> stored classifier for that bin (None when the bin has none).
    quantiles_to_bin: Dict[Any, Any] = field(default_factory=dict)
    # Combined bin -> training data collected for that bin.
    bin_training_data: Dict[Any, Any] = field(default_factory=dict)
    # Combined bin -> amount of data in that bin.
    combined_bins_lengths: Dict[Any, Any] = field(default_factory=dict)
    # Combined bin -> binary threshold for that bin's classifier.
    # NOTE(review): nothing in the visible code fills this in - confirm usage.
    bin_thresholds: Dict[Any, Any] = field(default_factory=dict)
    # Combined bin -> feature importances computed for that bin.
    bin_feature_importances: Dict[Any, Any] = field(default_factory=dict)


class ConstantModel:
    """
    Default model to fall back to when a bin does not have a classifier.

    The input values are ignored: both classes always get probability 0.5.
    """

    def predict_proba(self, dummy_input: np.ndarray) -> np.ndarray:
        """
        Return a (n_rows, 2) array filled with 0.5.

        One row is produced per row of a 2-D (or higher) `dummy_input`.
        Inputs without a 2-D shape (1-D arrays, None, plain objects) are
        treated as a single sample and yield a (1, 2) array.
        """
        shape = getattr(dummy_input, "shape", None)
        # only trust the leading dimension when the input is really a matrix
        n_rows = shape[0] if shape is not None and len(shape) >= 2 else 1
        return np.ones((n_rows, 2)) * 0.5


class LRBinsModel:
    """
    Binary classification model using individual logistic regression models
    for n_models of the combined bins.

    A combined bin is a unique combination of all of the feature bins,
    e.g. if there are 3 features that are each split in 2, then
    (0,0,0), (0,0,1), (0,1,0), ..., (1,1,1) are the 2^3=8 combined bins.
    This class has functions useful for operating on the bins.
    When the model sees a data row whose combined bin has no model, it
    falls back to the fallback_model to decide.

    n_bin_features is how many features are used to put the data into bins.
    n_inference_features is how many features are used in each bin's
    logistic regression model. n_bins_per_feature is how many bins each
    feature is divided into.
    """

    def __init__(
        self,
        fallback_model: Any = None,
        n_bin_features: int = 7,
        n_inference_features: int = 20,
        n_bins_per_feature: int = 2,
        n_models: int = 50,
        feature_importances: List[float] = None,
        default_threshold: float = 0.5,
        inference_on_all_bins: bool = True,
        first_stage_threshold: float = 0.006,
        xgb_model: Any = None,
        get_bin_feature_importances: bool = False,
        edge_interval_bounds: str = "inclusive",
        sort_with_metric: str = "accuracy",
    ):
        """
        Store the configuration and create the empty per-bin state.

        When no `fallback_model` is given, a `ConstantModel` is used for
        rows whose combined bin has no classifier.
        """
        # --- binning / model-size configuration ---
        self.n_bin_features = n_bin_features
        self.n_inference_features = n_inference_features
        self.n_bins_per_feature = n_bins_per_feature
        self.n_models = n_models
        self.edge_interval_bounds = edge_interval_bounds

        # --- collaborating models ---
        self.fallback_model = (
            ConstantModel() if fallback_model is None else fallback_model
        )
        self.xgb_model = xgb_model
        self.feature_importances = feature_importances
        self.get_bin_feature_importances = get_bin_feature_importances

        # --- decision / bin-selection configuration ---
        self.default_threshold = default_threshold
        self.first_stage_threshold = first_stage_threshold
        self.inference_on_all_bins = inference_on_all_bins
        self.sort_with_metric = sort_with_metric

        # --- state filled in by later calls ---
        self.data_to_bins_map = DataToBinsMap()
        self.coverage = -1.0  # placeholder until predict_proba computes it
        self.inference_bins = None
        self.thresholds = None

    def predict_proba_one_row(
        self, X_bin_row: np.ndarray, X_inference_row: np.ndarray, full_X_row: np.ndarray
    ) -> float:
        """
        Predict probability the classes according to a single row of X.
        """
        value = self.data_to_bins_map.quantiles_to_bin.get(tuple(X_bin_row), None)
        if value is None:
            prob = self.fallback_model.predict_proba(full_X_row.reshape(1, -1))[0, 1]
        else:
            self.coverage += 1
            data_mean, data_std, eps, weights = value
            X_inference_row = (X_inference_row - np.array(data_mean)) / (
                np.array(data_std) + eps
            )
            # compute the logistic regression equation
            X_inference_row_with_bias = np.hstack((1.0, X_inference_row))
            z = np.array(weights).dot(X_inference_row_with_bias)
            prob = 1 / (1 + np.exp(-z))
        return prob

    def predict_proba(self, X_test: np.ndarray) -> np.ndarray:
        """
        Predict probabilities of the classes according to X_test.
        """
        # get important features
        bin_X_test = self.get_important_features(
            X_test, self.feature_importances, self.n_bin_features
        )
        inference_X_test = self.get_important_features(
            X_test, self.feature_importances, self.n_inference_features
        )

        # get the combined bins for the test data
        bin_X_test_quantiles = self.get_combined_bins_of_data(bin_X_test)

        # evaluate the corresponding logistic regression model based on the combined bin
        self.coverage = 0
        probs = []
        for X_bin_row, X_inference_row, full_X_row in zip(
            bin_X_test_quantiles, inference_X_test, X_test
        ):
            prob = self.predict_proba_one_row(X_bin_row, X_inference_row, full_X_row)
            probs.append(prob)
        probs = np.expand_dims(np.array(probs), 1)
        self.coverage = self.coverage / len(probs)
        return np.concatenate((np.zeros_like(probs), probs), axis=-1)

    def train_model(self, X_train: np.ndarray, y_train: np.ndarray):
        """
        Use the training data to construct the
        logistic regression with bins model.
        This is done in the following way. First we get the n_bin_features most important features
        and determine the quantiles of these features which are used to make the combined bins.
        For more information on combined bins, see the class comment.
        Then the data is sorted into combined bins. For each bin, a logistic regression model
        is trained and the weights are stored.
        We are free to make predictions on any of these combined bins later using the per-bin 
        logistic regression weights although, in practice, we only do this on the best performing bins.
        """
        # get important features
        bin_X_train = self.get_important_features(
            X_train, self.feature_importances, self.n_bin_features
        )
        inference_X_train = self.get_important_features(
            X_train, self.feature_importances, self.n_inference_features
        )

        # get data quantiles to determine bins using the n_bin_features
        self.set_combined_bins_of_data(bin_X_train, self.n_bins_per_feature)
        bin_X_train_quantiles = self.get_combined_bins_of_data(bin_X_train)

        # put data into combined bins
        for X_bin_row, X_inference_row, y_row in zip(
            bin_X_train_quantiles, inference_X_train, y_train
        ):
            data, labels = self.data_to_bins_map.bin_training_data.get(
                tuple(X_bin_row), ([], [])
            )
            data.append(X_inference_row)
            labels.append(y_row)
            self.data_to_bins_map.bin_training_data[tuple(X_bin_row)] = (data, labels)

        # get bin data info
        for bin in self.data_to_bins_map.bin_training_data.keys():
            data, labels = self.data_to_bins_map.bin_training_data[bin]
            self.data_to_bins_map.combined_bins_lengths[bin] = (
                np.sum(labels),
                len(labels),
            )
            if self.get_bin_feature_importances:
                data = np.array(data)
                labels = np.array(labels)
                clf = XGBClassifier()
                clf.fit(data, labels)
                local_feature_importances = clf.feature_importances_
                self.data_to_bins_map.bin_feature_importances[
                    bin
                ] = local_feature_importances

        # train a logistic regression model for each combined bin and store the weights
        for key in self.data_to_bins_map.bin_training_data.keys():
            data, labels = self.data_to_bins_map.bin_training_data[key]
            data = np.array(data)
            labels = np.array(labels)
            if len(np.unique(labels)) == 2:
                clf = LogisticRegression()
                data, data_mean, data_std, eps = self.normalize_data(data)
                clf.fit(data, labels)
                self.data_to_bins_map.quantiles_to_bin[key] = (
                    list(data_mean),
                    list(data_std),
                    eps,
                    list(np.hstack((clf.intercept_[:, None], clf.coef_))[0]),
                )
            else:
                self.data_to_bins_map.quantiles_to_bin[key] = None

    def normalize_data(
        self, X: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
        """
        Normalize the data about its means and dividing by standard deviation.
        """
        X_pd = pd.DataFrame(X)
        X_mean = np.array(X_pd.mean())
        X_std = np.array(X_pd.std())
        eps = 1e-4
        normalized_X = (X - X_mean) / (X_std + eps)
        return normalized_X, X_mean, X_std, eps

    def train_xgb_model(self, X_train: np.ndarray, y_train: np.ndarray):
        """
        Train a XGBoost classifier and extract its feature importances.
        """
        if self.xgb_model is not None:
            self.feature_importances = self.xgb_model.feature_importances_
        else:
            clf = XGBClassifier()
            clf.fit(X_train, y_train)
            self.xgb_model = clf
            if self.feature_importances is None:
                self.feature_importances = self.xgb_model.feature_importances_

    def get_important_features(
        self, X: np.ndarray, feature_importances: List[float], n_features: int
    ) -> np.ndarray:
        """
        Return a subset of the data X containing the n_features most important
        features in each row according to feature_importances.
        """
        feature_indices = np.argpartition(feature_importances, -n_features)[
            -n_features:
        ]
        X_subset = X[:, feature_indices]
        return X_subset

    def set_combined_bins_of_data(self, X: np.ndarray, n_bins_per_feature: int):
        """
        Sets the quantile array of the data dictating the combined bins.
        """
        X_pd = pd.DataFrame(X)
        self.data_to_bins_map.quantiles = X_pd.quantile(
            np.linspace(0, 1, n_bins_per_feature + 1)
        )

    def get_combined_bins_of_data(self, X: np.ndarray) -> np.ndarray:
        """
        Get the combined bins of the data X.
        Suppose X has m data rows and n features, and that
        set_combined_bins_of_data made data_to_bins_map.quantiles a
        (n, n_bins_per_feature) array. Then X_quantiles is a
        m x n x n_bins_per_feature array of bools indicating if data point at
        (m,n) is greater than the quantile at (n, n_bins_per_feature).
        Summing over the n_bins_per_feature axis (as is done in X_bins) gives
        the number of quantiles that data point (m,n) is greater than.
        This means each row of X_bins is a list of n integers corresponding to
        a bin for each feature which gives a combined bin that a python
        dictionary can hash when cast to a Tuple. Each row of X corresponds to
        a combined bin row of X_bins.
        """
        X_pd = pd.DataFrame(X)
        quants = np.array(self.data_to_bins_map.quantiles.T)

        if self.edge_interval_bounds == "unbounded":
            X_quantiles = np.expand_dims(X_pd, axis=-1) > np.expand_dims(
                quants[:, 1:-1], axis=0
            )
        elif self.edge_interval_bounds == "inclusive":
            X_quantiles = np.expand_dims(X_pd, axis=-1) > np.expand_dims(quants, axis=0)
            X_quantiles_equal = np.expand_dims(X_pd, axis=-1) == np.expand_dims(
                quants, axis=0
            )
            X_quantiles[..., 0] = np.logical_or(
                X_quantiles[..., 0], X_quantiles_equal[..., 0]
            )
        elif self.edge_interval_bounds == "exclusive":
            X_quantiles = np.expand_dims(X_pd, axis=-1) > np.expand_dims(quants, axis=0)
            X_quantiles_equal = np.expand_dims(X_pd, axis=-1) == np.expand_dims(
                quants, axis=0
            )
            X_quantiles = np.logical_or(
                X_quantiles, np.expand_dims(X_quantiles_equal[..., -1], -1)
            )
        else:
            raise ValueError("Invalid edge interval option")

        X_bins = np.sum(X_quantiles, axis=2)
        return X_bins

    def predict(self, X_test: np.ndarray) -> List[float]:
        """
        Get a list of class predictions for X_test.
        """
        # get important features
        bin_X_test = self.get_important_features(
            X_test, self.feature_importances, self.n_bin_features
        )
        # get bins for test data
        X_test_bins = self.get_combined_bins_of_data(bin_X_test)
        y_probs = self.predict_proba(X_test)[:, 1]
        y_preds = y_probs >= self.default_threshold
        return y_preds

    def get_inference_bins(
        self, X: np.ndarray, y: np.ndarray
    ) -> Tuple[List[Any], float]:
        """
        Using the data in `X` and the labels in `y`, compute the rocauc scores of
        each of the combined bins, then sort them so that when they are accumulated
        we can outperform xgboost by as much as possible. Return the bins in this
        accumulation so that we can determine which bins to use first-stage
        inference on.
        """
        # get important features
        bin_X = self.get_important_features(
            X, self.feature_importances, self.n_bin_features
        )
        inference_X = self.get_important_features(
            X, self.feature_importances, self.n_inference_features
        )
        # store the label, the lrbins prob, and the xgb prob in combined bins
        bin_X_quantiles = self.get_combined_bins_of_data(bin_X)
        bins_to_data = {}
        for X_bin_row, X_inference_row, full_X_row, y_true in zip(
            bin_X_quantiles, inference_X, X, y
        ):
            lrbins_prob = self.predict_proba_one_row(
                X_bin_row, X_inference_row, full_X_row
            )
            xgb_prob = self.xgb_model.predict_proba(full_X_row.reshape(1, -1))[:, 1][0]
            lrbins_pred = int(lrbins_prob >= self.default_threshold)
            xgb_pred = self.xgb_model.predict(full_X_row.reshape(1, -1))[0]
            probs_and_true = bins_to_data.get(tuple(X_bin_row), [])
            probs_and_true.append(
                [y_true, lrbins_prob, xgb_prob, lrbins_pred, xgb_pred]
            )
            bins_to_data[tuple(X_bin_row)] = probs_and_true

        # evaluate rocauc on the combined bins
        lrbins_combined_bin_score = {}
        xgb_combined_bin_score = {}
        for key, probs_and_true in bins_to_data.items():
            probs_and_true = np.array(probs_and_true)
            y_trues = probs_and_true[:, 0]
            y_lrbins = probs_and_true[:, 1]
            y_xgb = probs_and_true[:, 2]
            y_lrbins_preds = probs_and_true[:, 3]
            y_xgb_preds = probs_and_true[:, 4]
            if len(np.unique(y_trues)) == 2:
                if self.sort_with_metric == "rocauc":
                    lrbins_combined_bin_score[key] = roc_auc_score(y_trues, y_lrbins)
                    xgb_combined_bin_score[key] = roc_auc_score(y_trues, y_xgb)
                elif self.sort_with_metric == "accuracy":
                    lrbins_combined_bin_score[key] = accuracy_score(
                        y_trues, y_lrbins_preds
                    )
                    xgb_combined_bin_score[key] = accuracy_score(y_trues, y_xgb_preds)
            else:
                lrbins_combined_bin_score[key] = 0.0
                xgb_combined_bin_score[key] = 0.0

        all_bins = list(lrbins_combined_bin_score.keys())
        lrbins_bin_scores = []
        xgb_bin_scores = []
        for key in all_bins:
            lrbins_bin_scores.append(lrbins_combined_bin_score[key])
            xgb_bin_scores.append(xgb_combined_bin_score[key])

        # sort by performance
        all_bins = np.array(all_bins)
        lrbins_bin_scores = np.array(lrbins_bin_scores)
        xgb_bin_scores = np.array(xgb_bin_scores)
        sort_indices = (xgb_bin_scores - lrbins_bin_scores).argsort()
        all_bins = all_bins[sort_indices]
        lrbins_bin_scores = lrbins_bin_scores[sort_indices]
        xgb_bin_scores = xgb_bin_scores[sort_indices]

        # construct cumulative bins and rocaucs
        cumulative_y_trues = []
        cumulative_y_lrbins = []
        cumulative_y_xgbs = []
        cumulative_y_lrbins_preds = []
        cumulative_y_xgbs_preds = []
        cumulative_lrbins_score = []
        cumulative_xgb_score = []
        fraction_of_data = []
        for key in all_bins:
            probs_and_true = bins_to_data[tuple(key)]
            probs_and_true = np.array(probs_and_true)
            y_trues = probs_and_true[:, 0]
            y_lrbins = probs_and_true[:, 1]
            y_xgb = probs_and_true[:, 2]
            y_lrbins_preds = probs_and_true[:, 3]
            y_xgb_preds = probs_and_true[:, 4]
            cumulative_y_trues += list(y_trues)
            cumulative_y_lrbins += list(y_lrbins)
            cumulative_y_xgbs += list(y_xgb)
            cumulative_y_lrbins_preds += list(y_lrbins_preds)
            cumulative_y_xgbs_preds += list(y_xgb_preds)

            if len(np.unique(cumulative_y_trues)) == 2:
                if self.sort_with_metric == "rocauc":
                    lrbins_roc_auc_score = roc_auc_score(
                        cumulative_y_trues, cumulative_y_lrbins
                    )
                    xgb_roc_auc_score = roc_auc_score(
                        cumulative_y_trues, cumulative_y_xgbs
                    )
                if self.sort_with_metric == "accuracy":
                    lrbins_roc_auc_score = accuracy_score(
                        cumulative_y_trues, cumulative_y_lrbins_preds
                    )
                    xgb_roc_auc_score = accuracy_score(
                        cumulative_y_trues, cumulative_y_xgbs_preds
                    )
            else:
                lrbins_roc_auc_score = 0.0
                xgb_roc_auc_score = 0.0

            cumulative_lrbins_score.append(lrbins_roc_auc_score)
            cumulative_xgb_score.append(xgb_roc_auc_score)
            fraction_of_data.append(len(cumulative_y_trues) / len(X))

        # pick the bins which cumulatively have a LRBins rocauc within `self.first_stage_threshold` of xgboost
        xgb_lrbins_cumulative_diff = np.array(cumulative_xgb_score) - np.array(
            cumulative_lrbins_score
        )
        self.thresholds = xgb_lrbins_cumulative_diff
        bins = all_bins[
            np.argwhere(xgb_lrbins_cumulative_diff <= self.first_stage_threshold)
        ]
        if self.first_stage_threshold < xgb_lrbins_cumulative_diff[0]:
            first_stage_coverage = 0.0
        else:
            first_stage_coverage = fraction_of_data[
                np.argwhere(xgb_lrbins_cumulative_diff <= self.first_stage_threshold)[
                    -1, 0
                ]
            ]
        if len(bins) != 0:
            bins = bins.reshape(-1, bins.shape[-1])
            bins = list(map(tuple, bins))
        return bins, first_stage_coverage

    def filter_bad_bins(self):
        """
        Remove the models from combined bins not in `self.inference_bins`.
        """
        for key in self.data_to_bins_map.quantiles_to_bin.keys():
            if key not in self.inference_bins:
                self.data_to_bins_map.quantiles_to_bin[key] = None

    def fit(
        self,
        X_train: np.ndarray,
        y_train: np.ndarray,
        X_eval: np.ndarray = None,
        y_eval: np.ndarray = None,
    ):
        """
        Fit the model to the training data.
        """
        if self.feature_importances is None or self.xgb_model is None:
            self.train_xgb_model(X_train, y_train)
        self.train_model(X_train, y_train)
        if not self.inference_on_all_bins:
            if X_eval is None or y_eval is None:
                raise ValueError("Dataset to sort bins must be passed to fit.")
            self.inference_bins, self.first_stage_coverage = self.get_inference_bins(
                X_eval, y_eval
            )
            self.copy_of_data_to_bins_map = copy.deepcopy(self.data_to_bins_map)
            self.filter_bad_bins()

    def performance(self, X_test: np.ndarray, y_test: np.ndarray) -> Dict[str, float]:
        """
        Evaluate the model according to several metrics.

        Returns a dictionary with the ROC AUC of the positive-class
        probabilities, the precision, recall and accuracy of the
        predictions at `self.default_threshold`, the fraction of rows
        predicted positive and the coverage recorded by `predict_proba`.
        """
        # run inference once and derive the class predictions from it, using
        # the same rule as `predict`
        y_probs = self.predict_proba(X_test)[:, 1]
        y_preds = y_probs >= self.default_threshold

        metrics = {}
        metrics["rocauc"] = roc_auc_score(y_test, y_probs)
        metrics["precision"] = precision_score(y_test, y_preds)
        metrics["recall"] = recall_score(y_test, y_preds)
        metrics["accuracy"] = accuracy_score(y_test, y_preds)
        # the key spelling is kept as is because callers may look it up
        metrics["positive predicion percentage"] = np.sum(y_preds) / len(y_preds)
        metrics["coverage"] = self.coverage
        return metrics
